
[TOPI][GPU] Mergepath sort with odd-even block sort #7611

Merged 12 commits into apache:main on Mar 16, 2021

Conversation

@mbrookhart (Contributor) commented Mar 8, 2021

Building on @masahi's work to add a while loop to TIR, this PR implements the MergePath algorithm to parallelize the later stages of mergesort. It also implements a stable in-shared-memory odd-even sort to do the small block sorting before mergesort for better speed. All told, this optimization gives us dramatic speedups on both AMD and Nvidia GPUs and trades blows with our thrust implementation for many shapes; see the benchmarks below. cc @Laurawly @zhiics @icemelon9 @csullivan @tkonolige

I had to add a path for skipping thread planning to get the hand-threaded odd-even-transpose sort to work, cc @tqchen @junrushao1994
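
For readers new to the block-sort step, here is a minimal host-side sketch in plain Python (illustrative only, not this PR's TIR) of the stable odd-even transposition sort used as the per-block pre-sort; on the GPU, every compare-and-swap pass is one synchronized step across the threads of a block working in shared memory.

# Illustrative serial model of the per-block odd-even transposition sort.
# On the GPU, each (i, i+1) compare-and-swap within a pass runs on its own thread.
def odd_even_transpose_sort(block):
    n = len(block)
    for phase in range(n):                    # n phases guarantee the block is sorted
        for i in range(phase % 2, n - 1, 2):  # disjoint pairs -> safe to do in parallel
            if block[i] > block[i + 1]:       # strict '>' keeps equal keys in order (stable)
                block[i], block[i + 1] = block[i + 1], block[i]
    return block

print(odd_even_transpose_sort([5, 1, 4, 1, 3, 2]))  # [1, 1, 2, 3, 4, 5]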

                   thrust_R9-Nano  mergepath_R9-Nano  thrust_1070Ti   mergepath_1070Ti     thrust_V100  mergepath_V100
[1000, 1, 1]              0.60 ms            0.15 ms        0.05 ms            0.05 ms         0.07 ms         0.04 ms
[1, 1000, 1]              0.59 ms            0.14 ms        0.05 ms            0.05 ms         0.07 ms         0.04 ms
[1, 1, 1000]              0.71 ms            0.13 ms        0.04 ms            0.05 ms         0.06 ms         0.03 ms
[1000, 10, 10]            1.65 ms            0.25 ms        0.96 ms            0.27 ms         1.81 ms         0.08 ms
[10, 1000, 10]            1.65 ms            0.25 ms        0.80 ms            0.26 ms         1.64 ms         0.08 ms
[10, 10, 1000]            1.63 ms            0.23 ms        0.95 ms            0.25 ms         1.82 ms         0.07 ms
[1000, 100, 100]         29.29 ms           18.59 ms       47.15 ms           33.82 ms        18.74 ms         7.43 ms
[100, 1000, 100]         28.63 ms           16.35 ms       37.36 ms           23.57 ms        17.44 ms         5.28 ms
[100, 100, 1000]         29.62 ms           15.13 ms       36.33 ms           21.11 ms        18.04 ms         4.64 ms
[10000, 1, 1]             0.78 ms            0.25 ms        0.14 ms            0.10 ms         0.46 ms         0.08 ms
[1, 10000, 1]             0.79 ms            0.25 ms        0.14 ms            0.10 ms         0.45 ms         0.08 ms
[1, 1, 10000]             0.76 ms            0.24 ms        0.13 ms            0.09 ms         0.45 ms         0.08 ms
[10000, 10, 10]           5.27 ms            2.68 ms        4.58 ms            3.36 ms         3.78 ms         0.78 ms
[10, 10000, 10]           5.03 ms            2.56 ms        4.47 ms            3.33 ms         3.78 ms         0.73 ms
[10, 10, 10000]           5.34 ms            2.48 ms        4.34 ms            3.08 ms         3.76 ms         0.70 ms
[10000, 100, 100]   Out of memory          338.65 ms      467.08 ms          437.17 ms       145.14 ms        96.20 ms
[100, 10000, 100]   Out of memory          284.24 ms      410.39 ms          380.18 ms       134.70 ms        84.63 ms
[100, 100, 10000]   Out of memory          250.24 ms      367.42 ms          305.95 ms       111.16 ms        59.69 ms
[100000, 1, 1]            0.73 ms            0.51 ms        0.16 ms            0.40 ms         0.50 ms         0.14 ms
[1, 100000, 1]            0.73 ms            0.51 ms        0.27 ms            0.40 ms         0.66 ms         0.14 ms
[1, 1, 100000]            0.71 ms            0.50 ms        0.23 ms            0.39 ms         0.69 ms         0.13 ms
[100000, 10, 10]         26.64 ms           33.12 ms       39.23 ms           39.64 ms        18.65 ms         8.48 ms
[10, 100000, 10]         26.80 ms           30.91 ms       37.43 ms           37.88 ms        17.74 ms         7.67 ms
[10, 10, 100000]         25.46 ms           29.75 ms       36.33 ms           35.59 ms        16.49 ms         7.17 ms
[1000000, 1, 1]           1.91 ms            3.68 ms        1.18 ms            4.36 ms         1.00 ms         1.01 ms
[1, 1000000, 1]           1.91 ms            3.68 ms        1.23 ms            4.32 ms         1.00 ms         1.01 ms
[1, 1, 1000000]           1.91 ms            3.64 ms        1.15 ms            4.34 ms         1.01 ms         0.99 ms
[1000000, 10, 10]   Out of memory          449.20 ms      410.63 ms          499.69 ms       136.20 ms       115.64 ms
[10, 1000000, 10]   Out of memory          385.10 ms      366.33 ms          443.03 ms       116.31 ms        94.27 ms
[10, 10, 1000000]   Out of memory          370.27 ms      364.14 ms          419.95 ms      111.16 ms         87.82 ms
[4507]                    0.70 ms            0.21 ms        0.14 ms            0.07 ms         0.46 ms         0.06 ms
[1, 122640]               0.78 ms            0.53 ms        0.17 ms            0.45 ms         0.70 ms         0.14 ms
[1, 120000]               0.73 ms            0.52 ms        0.28 ms            0.44 ms         0.69 ms         0.14 ms
[1, 30000]                0.74 ms            0.28 ms        0.16 ms            0.14 ms         0.46 ms         0.09 ms
[1, 7500]                 0.73 ms            0.21 ms        0.13 ms            0.08 ms         0.46 ms         0.07 ms
[1, 1000]                 0.65 ms            0.13 ms        0.04 ms            0.04 ms         0.07 ms         0.04 ms

@mbrookhart changed the title from "Mergepath sort with odd-even block sort" to "[TOPI][GPU] Mergepath sort with odd-even block sort" on Mar 8, 2021
@masahi (Member) commented Mar 8, 2021

@mbrookhart This doesn't work on VK/SPIR-V because SPIR-V requires the thread block size to be a compile-time constant; see

ICHECK(sizeptr) << "SPIRV only allows constant thread group size "

In particular, the following size selection has an issue, because width is not a constant:

ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32")
nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")

So to work around this problem, can we either

  • use the old TIR sort for VK, or
  • (better) use a constant thread block size for VK?
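
A minimal sketch of the second option, not this PR's actual code (it assumes an ir_builder-style kernel; the buffer arguments and the per-element body are placeholders): launch a compile-time-constant block size and let the threads past `width` do nothing, which is essentially what @tkonolige suggests below.

import tvm
from tvm import te

# Illustrative fixed-block-size kernel body: the thread extent is a Python
# constant, so the SPIR-V backend sees a constant workgroup size, while the
# number of blocks can stay a dynamic expression.
def fixed_block_copy_ir(data_buf, out_buf, width, block_size=128):
    ib = tvm.tir.ir_builder.create()
    data = ib.buffer_ptr(data_buf)
    out = ib.buffer_ptr(out_buf)
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", block_size)  # constant block size
    ib.scope_attr(bx, "thread_extent", (width + block_size - 1) // block_size)
    tid = bx * block_size + tx
    with ib.if_scope(tid < width):  # extra threads in the last block no-op
        out[tid] = data[tid]        # placeholder for the real per-element work
    return ib.get()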

@tkonolige (Contributor) commented:

Can we just run the full thread block size and nop on the extra threads?

## Perhaps we can autotune?
block_size = 128
thread_work = 4

Member: cc @merrymercy, can we use autotvm for this? i.e., in the traditional auto-tuner way?

Member: sure

Contributor Author: Any tips on how to do it?

@masahi (Member) commented Mar 9, 2021

@mbrookhart Can you comment on the difference in how the two-level merge path is done between your implementation and moderngpu https://github.com/moderngpu/moderngpu/blob/master/src/moderngpu/kernel_mergesort.hxx#L97?

moderngpu launches a separate, device-wide merge path kernel before invoking the block-level merge path kernel. Your code seems to launch two merge paths from one kernel, is that right?

@masahi self-assigned this on Mar 9, 2021
@mbrookhart (Contributor Author) commented:

@masahi I was having a hard time getting the partition information back out when I used a separate kernel, mostly due to handling memory in the IR builder. Running the binary search in the kernel that will eventually merge the partitions was easier to handle from a memory-passing standpoint. I could go back and try temporary allocations to see if the added complexity gets us any performance improvement. Still thinking about the threading for Vulkan.
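
For reference, a minimal plain-Python sketch of the merge path diagonal binary search under discussion (the textbook formulation, not this PR's TIR): given two sorted lists and a diagonal index diag, it returns how many of the first diag merged outputs come from each list, so each block or thread can then merge its own fixed-size chunk independently.

# Illustrative merge path partition search (lower-bound variant).
def merge_path_partition(a, b, diag):
    lo = max(0, diag - len(b))  # fewest elements `a` can contribute to the first diag outputs
    hi = min(diag, len(a))      # most elements `a` can contribute
    while lo < hi:
        mid = (lo + hi) // 2
        if a[mid] < b[diag - 1 - mid]:  # split point lies further along `a`
            lo = mid + 1
        else:
            hi = mid
    return lo, diag - lo  # (elements taken from a, elements taken from b)

# The first 3 outputs of merging [1, 3, 5] with [2, 4, 6] take 2 keys from a and 1 from b.
print(merge_path_partition([1, 3, 5], [2, 4, 6], 3))  # (2, 1)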

@masahi (Member) commented Mar 9, 2021

There is an issue with dynamic topk? I saw the same failure yesterday

@mbrookhart (Contributor Author) commented:

I've got a dynamic topk failure, but I can't reproduce it locally; that test passes with my versions of the dependencies. Maybe I can run the CI docker locally.

@mbrookhart (Contributor Author) commented:

@masahi I pushed the change to the Vulkan threading; could you test it on your GPU?

I built the CI docker on my machine and attempted to reproduce the dynamic topk failure, but I was still unable to. I'm not sure where to go next.

@masahi (Member) commented Mar 10, 2021

On VK, it fails on the (122640, 1) workload; rocm works fine.

Traceback (most recent call last):
  File "vk_test.py", line 146, in <module>
    test_argsort()
  File "vk_test.py", line 94, in test_argsort
    verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
  File "vk_test.py", line 86, in verify_argsort
    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype(dtype), rtol=1e-5)
  File "/home/masa/projects/dev/tvm/python/tvm/testing.py", line 82, in assert_allclose
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)
  File "/home/masa/anaconda3/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 1532, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/masa/anaconda3/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 846, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-05, atol=1e-07

Mismatched elements: 122633 / 122640 (100%)
Max absolute difference: 122388
Max relative difference: 84808.
 x: array([[116254,  14451, 111687, ...,  64552, 122405, 122011]], dtype=int32)
 y: array([[116254,  14451, 111687, ...,  14845, 100388,  53504]], dtype=int32)

@mbrookhart (Contributor Author) commented:

I validated that the threading change passed tests when I applied it to cuda; it just hurt performance :( I wonder if I'm hitting another spirv codegen issue?

@masahi (Member) commented Mar 10, 2021

Yeah, the same fixed-size threading applied to rocm works fine, so this is definitely a spirv issue. I'll see what's going on.

@masahi (Member) commented Mar 10, 2021

Interesting! I did a binary search for which size fails on VK: sizes up to 2048 work fine, but starting from 2049 it fails.

@masahi (Member) commented Mar 10, 2021

@mbrookhart Also, on the failure from test_any.py: the last test case runs on a 0-size input, and it fails on both rocm and vulkan. On VK I get a segfault; on rocm I get Floating point exception (core dumped). Maybe you want to check whether your code is safe for empty input?

With the old sort IR, rocm works fine with empty input. VK still gives a segfault, but that's another story :)
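
A minimal sketch of the kind of guard being asked for, not the PR's actual fix (the commit list below includes "only sort if the data is non-empty"); it assumes an ir_builder-style sort where `size` is the extent being sorted:

import tvm

# Illustrative only: wrap the whole sort body in a guard so an empty input
# emits no work at all.
def guarded_sort_ir(data_buf, size):
    ib = tvm.tir.ir_builder.create()
    data = ib.buffer_ptr(data_buf)
    with ib.if_scope(size > 0):  # skip the block sort and merge passes entirely
        data[0] = data[0]        # placeholder for the real sort body
    return ib.get()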

@mbrookhart (Contributor Author) commented:

@masahi that might be the point at which I switch from single-level mergepath to dual-level mergepath. I wonder if vulkan is having trouble with the extra control flow there.

@masahi (Member) commented Mar 10, 2021

Interesting, is it possible to always force single level?

@masahi (Member) commented Mar 16, 2021

I confirmed that this sort also works with the VK/SPIR-V backend, with the following great results:

                   main_vulkan mergepath_vk mergepath_rocm
[1000, 1, 1]           3.04 ms      0.90 ms        0.12 ms
[1, 1000, 1]           3.04 ms      0.88 ms        0.12 ms
[1, 1, 1000]           3.04 ms      0.73 ms        0.11 ms
[1000, 10, 10]         3.49 ms      0.91 ms        0.24 ms
[10, 1000, 10]         3.49 ms      0.91 ms        0.23 ms
[10, 10, 1000]         3.25 ms      0.73 ms        0.22 ms
[1000, 100, 100]     124.09 ms     14.34 ms       19.13 ms
[100, 1000, 100]     186.12 ms     12.03 ms       17.07 ms
[100, 100, 1000]     119.85 ms     10.88 ms       15.87 ms
[10000, 1, 1]         31.27 ms      1.21 ms        0.24 ms
[1, 10000, 1]         31.23 ms      1.18 ms        0.24 ms
[1, 1, 10000]         31.20 ms      1.03 ms        0.23 ms
[10000, 10, 10]       40.54 ms      3.53 ms        2.72 ms
[10, 10000, 10]       41.29 ms      3.43 ms        2.60 ms
[10, 10, 10000]       38.88 ms      3.28 ms        2.52 ms
[10000, 100, 100]   3006.17 ms    270.40 ms      340.12 ms
[100, 10000, 100]   3327.44 ms    207.98 ms      286.45 ms
[100, 100, 10000]   1949.95 ms    179.33 ms      252.66 ms
[100000, 1, 1]       259.56 ms      1.37 ms        0.49 ms
[1, 100000, 1]       259.52 ms      1.30 ms        0.50 ms
[1, 1, 100000]       259.65 ms      1.14 ms        0.49 ms
[100000, 10, 10]     538.90 ms     33.35 ms       34.05 ms
[10, 100000, 10]     514.35 ms     31.01 ms       31.86 ms
[10, 10, 100000]     396.63 ms     29.90 ms       30.78 ms
[1000000, 1, 1]     2336.85 ms      5.68 ms        3.87 ms
[1, 1000000, 1]     2346.90 ms      5.69 ms        3.87 ms
[1, 1, 1000000]     2337.73 ms      5.51 ms        3.83 ms
[1000000, 10, 10]   6797.32 ms    411.09 ms      451.99 ms
[10, 1000000, 10]   6087.86 ms    347.81 ms      387.86 ms
[10, 10, 1000000]   4416.72 ms    333.54 ms      372.60 ms
[(4507,), 0]          15.22 ms      0.87 ms        0.20 ms
[(1, 122640), 1]     290.39 ms      1.18 ms        0.51 ms
[(1, 120000), 1]     285.99 ms      1.14 ms        0.51 ms
[(1, 30000), 1]       72.42 ms      0.98 ms        0.26 ms
[(1, 7500), 1]        18.76 ms      0.83 ms        0.19 ms
[(1, 1000), 1]         3.00 ms      0.68 ms        0.11 ms

@masahi merged commit d288bbc into apache:main on Mar 16, 2021
@masahi (Member) commented Mar 16, 2021

Thanks @mbrookhart

@mbrookhart deleted the mergepath_odd_even branch on March 16, 2021 at 15:28
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
* Mergepath sort with odd-even block sort

* fix lint, add test

* respond to review comments

* speed up tests by reducing dtype skews

* fix bad rebase

* change threading to support vulkan

* fix lint

* only sort if the data is non-empty

* fix lint again

* fix for vk

* move if to higher scope

* fix typo

Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request May 11, 2021